"""
Apply spell-checking (Nr3D only) and tokenization to the referential utterances.
Add meta-information to each utterance, e.g., whether the target is referred to by its class name.
The output is a .csv file that can be used for training, analysis, etc. of the ReferIt3D data.
"""
import argparse
import pprint

import nltk
import pandas as pd
from symspellpy.symspellpy import SymSpell

from referit3d.in_out.nr3d import load_nr3d_raw_data
from referit3d.in_out.sr3d import load_sr3d_raw_data
from referit3d.in_out.neural_net_oriented import load_scan_related_data
from referit3d.utils import read_lines

from referit3d.data_generation.nr3d.tokenization import pre_process_text
from referit3d.data_generation.nr3d.tokenization import sentence_spelling_dictionary
from referit3d.data_generation.nr3d.tokenization import token_spelling_dictionary

from referit3d.analysis.utterances import mentions_target_class


def parse_arguments():
    """Parse and validate the command-line arguments of this script.

    Only the raw-data file that matches ``-type`` has to be provided:
    ``--nr3d-file`` when ``-type nr3d`` and ``--sr3d-file`` when ``-type sr3d``.
    The other option may be omitted, in which case its value is ``None``.

    Returns:
        argparse.Namespace: the parsed arguments, extended with the hard-coded
        ``word_freq_file`` and ``vocab_file`` paths. Both paths are relative,
        so they are resolved against the current working directory.

    Raises:
        SystemExit: raised by argparse on invalid arguments, including when the
            raw-data file matching ``-type`` is missing.
    """
    parser = argparse.ArgumentParser(description='Prepare Referntial Data For ReferIt3DNet')

    parser.add_argument('-scannet-file', type=str, required=True,
                        help='pkl file containing the data of Scannet as generated by running prepare_scannet_data')
    parser.add_argument('-type', type=str, choices=['nr3d', 'sr3d'], required=True)
    parser.add_argument('-out-file', type=str, required=True, help='the preprocessed data save path')
    # Each of these is only needed for its own dataset type; which one is
    # mandatory is checked below, once -type is known.
    parser.add_argument('--nr3d-file', type=str, default=None,
                        help='raw Nr3D data file (required when -type is nr3d)')
    parser.add_argument('--sr3d-file', type=str, default=None,
                        help='raw Sr3D data file (required when -type is sr3d)')

    args = parser.parse_args()

    # `choices` guarantees args.type is either 'nr3d' or 'sr3d'.
    if getattr(args, '{}_file'.format(args.type)) is None:
        parser.error('--{0}-file is required when -type is {0}'.format(args.type))

    # Print them nicely
    args_string = pprint.pformat(vars(args))
    print(args_string)

    # Fixed language resources; set after printing so that only the
    # user-provided arguments are echoed.
    args.word_freq_file = '../../data/language/symspell_frequency_dictionary_en_82_765.txt'
    args.vocab_file = '../../data/language/glove.6B.100d.vocabulary.txt'

    return args


def _preprocess_nr3d(nr3d_file, word_freq_file, vocab, all_scans_in_dict):
    """Load the raw Nr3D data, spell-check/tokenize it and attach metadata.

    Args:
        nr3d_file: path passed to ``load_nr3d_raw_data``.
        word_freq_file: frequency dictionary loaded into the SymSpell speller.
        vocab: set of words passed to ``pre_process_text`` as the golden vocabulary.
        all_scans_in_dict: scan data forwarded to ``mentions_target_class``.

    Returns:
        The loaded data with the added columns ``tokens``, ``dataset`` (always
        ``'nr3d'``) and ``mentions_target_class``.
    """
    raw_data = load_nr3d_raw_data(nr3d_file)
    print('loaded Nr3D utterances:', len(raw_data))
    print('Success Rate utterances:', raw_data.correct_guess.mean())

    # Create a speller
    speller = SymSpell()
    print('Spell check Loaded:', speller.load_dictionary(word_freq_file, term_index=0, count_index=1))

    # Preprocess the text; only the spell-corrected tokens and the words that
    # could not be resolved are used below.
    _, _, spelled_tokens, _, missed_words = pre_process_text(raw_data.utterance,
                                                             sentence_spelling_dictionary,
                                                             token_spelling_dictionary,
                                                             tokenizer=nltk.word_tokenize,
                                                             golden_vocabulary=vocab,
                                                             token_speller=speller)

    print('After spell-checking, #mised words:', len(missed_words))
    for word in sorted(missed_words):
        print(' ', word)

    # Add metadata
    raw_data['tokens'] = spelled_tokens
    processed_data = raw_data.assign(dataset='nr3d')
    processed_data['mentions_target_class'] = processed_data.apply(
        lambda row: mentions_target_class(row, all_scans_dict=all_scans_in_dict), axis=1)
    print('mentions-class', processed_data['mentions_target_class'].mean())
    return processed_data


def _preprocess_sr3d(sr3d_file):
    """Load the raw Sr3D data, tokenize it and attach metadata.

    Args:
        sr3d_file: path passed to ``load_sr3d_raw_data``.

    Returns:
        The loaded data with the added columns ``tokens``, ``dataset`` (always
        ``'sr3d'``) and ``mentions_target_class`` (always ``True``).
    """
    raw_data = load_sr3d_raw_data(sr3d_file)

    # Since sr3d has a very simple language. No need for 'complex' tokenization, e.g. spell checking.
    # Raw string: the backslash is a literal punctuation character to strip,
    # not the start of an escape sequence.
    basic_punct = r'.?!,:;/\-~*_='
    punct_to_space = str.maketrans(basic_punct, ' ' * len(basic_punct))  # map punctuation to space
    raw_data['tokens'] = raw_data.utterance.apply(
        lambda x: nltk.word_tokenize(x.lower().translate(punct_to_space)))

    # Add metadata; mentions_target_class is True by definition for sr3d.
    return raw_data.assign(dataset='sr3d', mentions_target_class=True)


def main():
    """Preprocess the Nr3D or Sr3D utterances and save the result as a .csv file.

    The dataset to process, the input files and the output path are taken from
    the command line (see ``parse_arguments``).

    Raises:
        ValueError: if the requested dataset type is neither 'nr3d' nor 'sr3d'.
    """
    # Parse arguments
    args = parse_arguments()

    # Read-scans (only the per-scan dictionary is needed here)
    all_scans_in_dict, _, _ = load_scan_related_data(args.scannet_file)

    # Read the vocab (used in preprocessing nr3d). It is loaded for both
    # dataset types, so the vocabulary file has to exist in either case.
    vocab = set(read_lines(args.vocab_file))

    if args.type == 'nr3d':
        processed_data = _preprocess_nr3d(args.nr3d_file, args.word_freq_file, vocab, all_scans_in_dict)
    elif args.type == 'sr3d':
        processed_data = _preprocess_sr3d(args.sr3d_file)
    else:
        raise ValueError('Unknown dataset type: {!r}'.format(args.type))

    # Save the preprocessed data
    processed_data.to_csv(args.out_file, index=False)


# Run the preprocessing only when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
    main()
